% ------------------Variational Soft Symbol Decoding (VSSD)----------------
%
% Bayesian algorithm to decode the soft symbols from the measurements y =
% Hs + w, where,
%
% y: NLx1 vector of measurements (L>=2, to satisfy Nyquist sampling)
% s: 2Nx1 vector consisting of the real and imaginary parts of the N
%    (complex) symbols
% H: NLx2N real valued channel matrix
% w: NLx1 Gaussian noise vector with zero mean and covariance
%    sigma^2 I_{NL}
%
% In this code, we compare the performance of the newly proposed VSSD with
% the MMSE decoder for the (benchmark) IID Gaussian MIMO channel. The
% function qSolver.m performs fixed point iterations for VSSD.
%
%References
%   1) Arunkumar K.P. and Chandra R. Murthy, "Variational Soft Symbol
%   Decoding for Sweep Spread Carrier Based Underwater Acoustic
%   Communication", SPAWC 2019, Cannes, France.
%
%   2) Arunkumar K. P. and C. R. Murthy, Soft Symbol Decoding in
%   Sweep-Spread-Carrier Underwater Acoustic Communications: A Novel
%   Variational Bayesian Algorithm and its Analysis, Accepted, IEEE
%   Transactions on Signal Processing, Mar. 2020.
%
%Author  : Arunkumar K. P.
%Address : Ph.D. Scholar,
%          Signal Processing for Communications Lab, ECE Department,
%          Indian Institute of Science, Bangalore, India-560 012.
%Email   : arunkumar@iisc.ac.in
%
%
%Revision History
% Version : 1.2
% Last Revision: 15-03-2020
%
%
% This script/program is released under the Creative Commons Licence
% with Attribution Non-commercial Share Alike (by-nc-sa)
% http://creativecommons.org/licenses/by-nc-sa/3.0/
%
%
%++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
% Short Disclaimer: this script is for educational purposes only.
%++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

% Start from a clean workspace, command window and figure set, then apply
% the project's figure defaults (defpaper is a project helper, presumably
% setting default plot properties -- confirm).
clear; clc; close all;
defpaper;

%% Define parameters
% --- System dimensions and modulation ---
N      = 288;   % number of complex symbols
deltaH = 0;     % channel estimation error (e.g. 0.05 for imperfect CSI)
L      = 2;     % oversampling factor; L = 2 corresponds to Nyquist rate
mQAM   = 4;     % modulation order (4 -> QPSK)

channelModel = 'IID Gaussian MIMO';

% --- Monte Carlo stopping rules ---
NBITERRMIN = 200;   % bit errors to wait for before closing the Monte Carlo run

% SNR grid (dB) for the i.i.d. Gaussian MIMO channel
SNRs = [0 3 6 9 12 14 15];               % grid used for VSSD
% SNRs = [0 3 6 9 12 15 18 21 24 27];    % wider grid used for MMSE

NtrialMax = 1e5;    % maximum number of Monte Carlo runs per SNR
NtrialMin = 1;      % the run may stop only once trial > NtrialMin

% --- Decoder and reporting settings ---
Niters      = 100;  % maximum number of fixed point iterations for VSSD
displayRate = 0.5;  % results are displayed every displayRate minutes

decoderList = 'MMSE, VSSD';

%% Modulator and demodulator - initialization
% Bit-to-symbol mapper and symbol-to-bit demapper for rectangular mQAM
% (QPSK for mQAM = 4). Both objects are normalized to unit average symbol
% power. The demodulator keeps its default decision method, so it returns
% hard bits (the main loop compares them directly with the sent bits).
hMod = comm.RectangularQAMModulator( ...
    'ModulationOrder',     mQAM, ...
    'BitInput',            true, ...
    'NormalizationMethod', 'Average power');

hDemod = comm.RectangularQAMDemodulator( ...
    'ModulationOrder',     mQAM, ...
    'BitOutput',           true, ...
    'NormalizationMethod', 'Average power');

% Uncoded frame length in bits: N symbols carrying log2(mQAM) bits each
frmLen = log2(mQAM)*N;
    
%% Result arrays for each decoder (one entry per SNR point)
BER_MMSE      = zeros(numel(SNRs),1);   % MMSE bit error rate
BER_VSSD      = zeros(numel(SNRs),1);   % VSSD bit error rate
avgIters_VSSD = zeros(numel(SNRs),1);   % average number of VSSD fixed point iterations
itersSNR      = zeros(numel(SNRs),1);   % number of Monte Carlo trials performed per SNR

% Active/inactive status of each decoder: NaN until the main loop runs the
% decoder and sets its flag to true.
flagMMSE = NaN;
flagVSSD = NaN;

%% Main loop below
% decoderList never changes inside the loops, so decide once which decoders
% are enabled instead of re-running regexpi on every trial.
runMMSE = ~isempty(regexpi(decoderList,'MMSE'));
runVSSD = ~isempty(regexpi(decoderList,'VSSD'));

tSaver  = tic; % timer for saving results (not currently used by this loop)

tDispl  = tic; % start the timer for displaying results

for iSNR = 1:length(SNRs)
    
    SNR = SNRs(iSNR); %next SNR value
    
    for trial = 1: NtrialMax
        %% Measurement model
        % Generate random binary data
        sendBits = logical(randi([0 1], frmLen, 1));
        
        % Modulate the data: one M-QAM symbol per log2(mQAM) bits (N x 1)
        sQAM = step(hMod, sendBits);
        
        % Real & imaginary parts interleaved into the real-valued vector s
        s = zeros(2*N,1);
        s(1:2:end) = real(sQAM);
        s(2:2:end) = imag(sQAM);
        
        % Generate channel matrix, H: IID Gaussian
        H = randn(L*N, 2*N); %entries of the channel matrix are i.i.d. Gaussian
        Ps = N^2*L; %signal power, E(||Hs||^2) (H ~ IID Gaussian MIMO channel)
        
        % Generate the signal part of the measurements
        x = H*s;
        
        % Add estimation error to perfect CSI
        H = H + deltaH*randn(size(H)); %imperfect CSIR, deltaH controls the error
        
        % Generate (additive) noise part
        Pn = Ps*10^(-SNR/10)/(N*L); %noise power such that SNR is as specified
        
        w = sqrt(Pn)*randn(N*L,1); %AWGN
        
        % Generate Measurements, y
        y = x + w;
        
        % Linear MMSE estimate of s, shared by both decoders (the MMSE
        % decoder uses it directly, VSSD uses its signs as initialization).
        % 2*Pn = Pn/(1/2), i.e. noise power over the per-component power of
        % unit-average-power QAM symbols.
        if runMMSE || runVSSD
            sRX = (H'*H+2*Pn*eye(2*N))\(H'*y);
        end
        
        %% (1) MMSE Decoder
        if runMMSE
            flagMMSE = true; %set this decoder as active
            
            % Received symbols in complex form
            sRX_ = sRX(1:2:end) + 1i*sRX(2:2:end);
            
            % Demodulate the received signal (demodulator outputs hard bits)
            receivedBits = step(hDemod, sRX_(:));
            
            % Running average of the per-frame bit error rate
            BER_MMSE(iSNR) = ( BER_MMSE(iSNR)*(trial-1) ...
                + mean(abs(receivedBits-sendBits)) ) / trial;
        end
        
        %% (2) VSSD
        if runVSSD
            flagVSSD = true; %set this decoder as active
            
            % Initial soft-symbol vector: 0/1 decisions on the sign of the
            % MMSE estimate
            q0 = double(gt(sRX,0));
            qTol = 1e-3; % reset each trial: qSolver's third output overwrites it
            [ q, iter, qTol, ELBO, A ] = qSolver( q0, H, y, Pn, qTol, Niters);
            
            % Running average of the number of fixed point iterations
            avgIters_VSSD(iSNR) = ( avgIters_VSSD(iSNR)*(trial-1) + iter ) / trial;
            
            % Map soft symbols back to the symbol scale: q = 0 -> -1/sqrt(2),
            % q = 1 -> +1/sqrt(2) (q presumably holds probabilities in [0,1])
            sRX_ = (q(1:2:end)-0.5 + 1i*(q(2:2:end)-0.5))*sqrt(2);
            
            % Demodulate the received signal (demodulator outputs hard bits)
            receivedBits = step(hDemod, sRX_(:));
            
            % Running average of the per-frame bit error rate
            BER_VSSD(iSNR) = ( BER_VSSD(iSNR)*(trial-1) ...
                + mean(abs(receivedBits-sendBits)) ) / trial;
        end
        
        %% Quit Monte-Carlo loop once more than NBITERRMIN bit errors are
        % registered for every active decoder (inactive decoders have a NaN
        % flag, and min() ignores NaN entries)
        minBitErr =  min([BER_MMSE(iSNR)*flagMMSE...
                BER_VSSD(iSNR)*flagVSSD]*trial*length(sendBits));
        if (minBitErr > NBITERRMIN) && (trial > NtrialMin)
            disp(' ');
            disp([num2str(trial) ': ' num2str([BER_MMSE(iSNR)*flagMMSE...
                BER_VSSD(iSNR)*flagVSSD])...
                ' (SNR = ' num2str(SNR) ' dB)' ' [bitErrs = ' num2str(minBitErr) ']']);
            disp(['Monte-Carlo loop terminated since total bit error count > '...
                num2str(NBITERRMIN) ' for all active decoders and # of trials='...
                num2str(trial) ' > ' num2str(NtrialMin)]);
            break;
        end
        
        %% Display results (in the run time)
        if toc(tDispl) > displayRate*60
            % Print the running results and refresh the plots every
            % displayRate minutes
            tDispl  = tic; % restart the timer for displaying results
            disp(' ');
            disp([num2str(trial) ': ' num2str([BER_MMSE(iSNR)*flagMMSE...
                BER_VSSD(iSNR)*flagVSSD])...
                ' (SNR = ' num2str(SNR) ' dB)' ' [bitErrs = ' num2str(minBitErr) ']']);
            figure(1), cla, berPlots;
            figure(2), cla, bar(SNRs(1:iSNR), avgIters_VSSD(1:iSNR), 0.4), grid on
            xlabel('SNR (dB)'), ylabel('Average number of iterations')
            xlim([SNRs(1)-3/4 SNRs(iSNR)+3/4])
            
            pause(0.1) %waiting just a bit to get the plot window updated
        end
        
    end
    
    % Warn when the trial budget ran out before the error-count target was
    % met: the BER estimates at this SNR then rest on fewer bit errors.
    if ~((minBitErr > NBITERRMIN) && (trial > NtrialMin))
        disp(['Monte-Carlo loop reached NtrialMax = ' num2str(NtrialMax) ...
            ' trials before more than ' num2str(NBITERRMIN) ...
            ' bit errors were registered for all active decoders (SNR = ' ...
            num2str(SNR) ' dB)']);
    end
    
    itersSNR(iSNR) = trial; %save the number of Monte Carlo trials performed for this SNR
    
    figure(1), cla, berPlots; %plot BER curves
    
    figure(2), cla, bar(SNRs(1:iSNR), avgIters_VSSD(1:iSNR), 0.4), grid on
    xlabel('SNR (dB)'), ylabel('Average number of iterations')
    xlim([SNRs(1)-3/4 SNRs(iSNR)+3/4])
    
    pause(0.1);    %waiting just a bit to get the plot window updated
    
end